{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.functional as F\n",
    "import torch.optim as optim\n",
    "from torchvision import datasets,transforms"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "epochs = 10\n",
    "batch_size = 200\n",
    "learning_rate = 0.001"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "train_loader = torch.utils.data.DataLoader(\n",
    "    datasets.MNIST('../data',train=True,\n",
    "                    transform=transforms.Compose([transforms.ToTensor(),\n",
    "                    transforms.Normalize((0.1307,),(0.3081,))])),\n",
    "    batch_size = batch_size,shuffle = True)\n",
    "test_loader = torch.utils.data.DataLoader(\n",
    "    datasets.MNIST('../data',train=False,\n",
    "                    transform=transforms.Compose([transforms.ToTensor(),\n",
    "                    transforms.Normalize((0.1307,),(0.3081,))])),\n",
    "    batch_size = batch_size,shuffle = True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<torch.utils.data.dataloader.DataLoader at 0x10e775550>"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "train_loader"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "class MLP(nn.Module):\n",
    "    def __init__(self):\n",
    "        super(MLP,self).__init__()\n",
    "        self.model = nn.Sequential(\n",
    "            nn.Linear(784,200),\n",
    "            nn.ReLU(inplace = True),\n",
    "            nn.Linear(200,200),\n",
    "            nn.ReLU(inplace = True),\n",
    "            nn.Linear(200,10),\n",
    "            nn.ReLU(inplace=True)\n",
    "        )\n",
    "    \n",
    "    def forward(self,x):\n",
    "        x = self.model(x)\n",
    "        return x"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Train epoch: 0 [0/60000 (0%)] \t Loss: 2.302549\n",
      "Train epoch: 0 [20000/60000 (33%)] \t Loss: 0.972567\n",
      "Train epoch: 0 [40000/60000 (67%)] \t Loss: 0.975484\n",
      "\n",
      "Test set: Average loss: 0.0049, Accuracy: 6029/10000 (60%)\n",
      "\n",
      "Train epoch: 1 [0/60000 (0%)] \t Loss: 1.003026\n",
      "Train epoch: 1 [20000/60000 (33%)] \t Loss: 0.995098\n",
      "Train epoch: 1 [40000/60000 (67%)] \t Loss: 1.023734\n",
      "\n",
      "Test set: Average loss: 0.0048, Accuracy: 6064/10000 (60%)\n",
      "\n",
      "Train epoch: 2 [0/60000 (0%)] \t Loss: 0.948649\n",
      "Train epoch: 2 [20000/60000 (33%)] \t Loss: 1.025849\n",
      "Train epoch: 2 [40000/60000 (67%)] \t Loss: 0.923469\n",
      "\n",
      "Test set: Average loss: 0.0047, Accuracy: 6076/10000 (60%)\n",
      "\n",
      "Train epoch: 3 [0/60000 (0%)] \t Loss: 0.822239\n",
      "Train epoch: 3 [20000/60000 (33%)] \t Loss: 1.007290\n",
      "Train epoch: 3 [40000/60000 (67%)] \t Loss: 1.047988\n",
      "\n",
      "Test set: Average loss: 0.0047, Accuracy: 6061/10000 (60%)\n",
      "\n",
      "Train epoch: 4 [0/60000 (0%)] \t Loss: 0.988568\n",
      "Train epoch: 4 [20000/60000 (33%)] \t Loss: 0.917506\n",
      "Train epoch: 4 [40000/60000 (67%)] \t Loss: 1.051584\n",
      "\n",
      "Test set: Average loss: 0.0047, Accuracy: 6077/10000 (60%)\n",
      "\n",
      "Train epoch: 5 [0/60000 (0%)] \t Loss: 0.912386\n",
      "Train epoch: 5 [20000/60000 (33%)] \t Loss: 0.934699\n",
      "Train epoch: 5 [40000/60000 (67%)] \t Loss: 1.022685\n",
      "\n",
      "Test set: Average loss: 0.0047, Accuracy: 6061/10000 (60%)\n",
      "\n",
      "Train epoch: 6 [0/60000 (0%)] \t Loss: 1.007931\n",
      "Train epoch: 6 [20000/60000 (33%)] \t Loss: 1.010795\n",
      "Train epoch: 6 [40000/60000 (67%)] \t Loss: 0.795508\n",
      "\n",
      "Test set: Average loss: 0.0047, Accuracy: 6075/10000 (60%)\n",
      "\n",
      "Train epoch: 7 [0/60000 (0%)] \t Loss: 0.890678\n",
      "Train epoch: 7 [20000/60000 (33%)] \t Loss: 0.864746\n",
      "Train epoch: 7 [40000/60000 (67%)] \t Loss: 0.787719\n",
      "\n",
      "Test set: Average loss: 0.0047, Accuracy: 6092/10000 (60%)\n",
      "\n",
      "Train epoch: 8 [0/60000 (0%)] \t Loss: 1.020410\n",
      "Train epoch: 8 [20000/60000 (33%)] \t Loss: 0.822156\n",
      "Train epoch: 8 [40000/60000 (67%)] \t Loss: 1.032961\n",
      "\n",
      "Test set: Average loss: 0.0047, Accuracy: 6093/10000 (60%)\n",
      "\n",
      "Train epoch: 9 [0/60000 (0%)] \t Loss: 0.851646\n",
      "Train epoch: 9 [20000/60000 (33%)] \t Loss: 0.977710\n",
      "Train epoch: 9 [40000/60000 (67%)] \t Loss: 0.971712\n",
      "\n",
      "Test set: Average loss: 0.0048, Accuracy: 6087/10000 (60%)\n",
      "\n"
     ]
    }
   ],
   "source": [
    "model = MLP()\n",
    "optimizer = optim.Adam(model.parameters(),lr=learning_rate)\n",
    "criterion = nn.CrossEntropyLoss()\n",
    "\n",
    "for epoch in range(epochs):\n",
    "    for idx,(train,target) in enumerate(train_loader):\n",
    "        train = train.view(-1,28*28)\n",
    "        logits = model(train)\n",
    "        loss = criterion(logits,target)\n",
    "        optimizer.zero_grad()\n",
    "        loss.backward()\n",
    "        optimizer.step()\n",
    "    \n",
    "        if idx%100 == 0:\n",
    "            print('Train epoch: {} [{}/{} ({:.0f}%)] \\t Loss: {:.6f}'\n",
    "                      .format(epoch,idx*len(train),len(train_loader.dataset),100. * idx / len(train_loader),loss.item()))\n",
    "    test_loss = 0\n",
    "    correct = 0\n",
    "    for data,tar in test_loader:\n",
    "        data = data.view(-1,28*28)\n",
    "        logs = model(data)\n",
    "        test_loss += criterion(logs,tar).item()\n",
    "        pre_data = logs.data.max(1)[1]\n",
    "        correct += pre_data.eq(tar).sum()\n",
    "    test_loss /= len(test_loader.dataset)\n",
    "    print('\\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\\n'.format(\n",
    "        test_loss, correct, len(test_loader.dataset),\n",
    "        100. * correct / len(test_loader.dataset)))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
